ACKTR from scratch (low-level PyTorch) — CartPole-v1#

This notebook implements ACKTR-style policy optimization in low-level PyTorch:

  • Actor update uses a K-FAC preconditioner (approx. Fisher inverse) + trust-region clipping.

  • Critic is trained as a baseline (value function) with a simple first-order optimizer (Adam) for stability.

We log training dynamics and visualize them with Plotly, including episodic reward progression.

Prereqs:

  • PyTorch

  • Gymnasium

  • Plotly

Theory reference: see 00_overview.ipynb in this folder.

Notebook roadmap#

  1. Setup + environment

  2. Actor/Critic networks

  3. Rollout collection + GAE

  4. K-FAC optimizer (Linear layers)

  5. Training loop (ACKTR update)

  6. Plotly diagnostics (reward + KL + losses)

  7. Stable-Baselines ACKTR reference + hyperparameters

import os
import random
import time

import numpy as np
import pandas as pd
import plotly
import plotly.express as px
import plotly.graph_objects as go
import plotly.io as pio

import gymnasium as gym

import torch
import torch.nn as nn
from torch.distributions import Categorical

pio.templates.default = 'plotly_white'
pio.renderers.default = os.environ.get("PLOTLY_RENDERER", "notebook")
np.set_printoptions(precision=4, suppress=True)

print('NumPy', np.__version__)
print('Pandas', pd.__version__)
print('Plotly', plotly.__version__)
print('Gymnasium', gym.__version__)
print('Torch', torch.__version__)
NumPy 1.26.2
Pandas 2.1.3
Plotly 6.5.2
Gymnasium 1.1.1
Torch 2.7.0+cu126
# --- Reproducibility ---
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Keep the implementation CPU-friendly and deterministic-ish.
DEVICE = torch.device('cpu')
print('DEVICE', DEVICE)
DEVICE cpu
# --- Run configuration ---
FAST_RUN = True  # set False for a longer, smoother curve

# Environment
ENV_ID = 'CartPole-v1'

# Rollout / training
TOTAL_TIMESTEPS = 40_000 if FAST_RUN else 200_000
ROLLOUT_STEPS = 128

# Discounting / advantage
GAMMA = 0.99
GAE_LAMBDA = 0.95

# Loss weights
ENT_COEF = 0.00

# Critic optimizer
CRITIC_LR = 1e-3

# K-FAC / ACKTR knobs (actor)
ACTOR_LR = 0.10
KFAC_DAMPING = 0.03
KFAC_STATS_DECAY = 0.95
KFAC_CLIP = 0.01  # trust region / KL clip (see theory)
INVERSE_UPDATE_INTERVAL = 1

print('TOTAL_TIMESTEPS', TOTAL_TIMESTEPS)
TOTAL_TIMESTEPS 40000

1) Environment#

CartPole-v1 is a classic discrete-action benchmark:

  • state \(s \in \mathbb{R}^4\)

  • actions \(a \in \{0,1\}\)

  • reward \(r_t = 1\) per step until termination

It’s a good fit for a minimal ACKTR demonstration because the policy is a categorical distribution.

env = gym.make(ENV_ID)
obs_dim = int(env.observation_space.shape[0])
act_dim = int(env.action_space.n)

obs, _ = env.reset(seed=SEED)
print('obs_dim', obs_dim, 'act_dim', act_dim)
print('first obs', obs)
obs_dim 4 act_dim 2
first obs [ 0.0274 -0.0061  0.0359  0.0197]

2) Actor–critic parameterization#

We use two networks:

  • Actor: logits for a categorical policy \(\pi_\theta(a\mid s)\).

  • Critic: a value function baseline \(V_\phi(s)\).

The actor loss (policy gradient with entropy bonus) is:

\[ \mathcal{L}_{\pi}(\theta) = -\mathbb{E}\left[\log \pi_\theta(a\mid s)\,\hat A(s,a)\right] - \beta\,\mathbb{E}\left[\mathcal{H}(\pi_\theta(\cdot\mid s))\right]. \]

The critic trains by regression to (bootstrapped) returns:

\[ \mathcal{L}_V(\phi) = \tfrac{1}{2}\,\mathbb{E}\left[(V_\phi(s) - \hat R)^2\right]. \]
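To see how these formulas map onto PyTorch, here is a tiny illustration on hand-made tensors (all numbers are arbitrary; the real computation happens in the training loop below, and ENT_COEF comes from the configuration cell):

# Illustrative only: arbitrary logits / actions / advantages / returns.
_logits = torch.tensor([[0.2, -0.1], [0.0, 0.3]])   # actor logits for 2 states
_actions = torch.tensor([0, 1])
_adv = torch.tensor([1.5, -0.5])                     # advantage estimates \hat A
_dist = Categorical(logits=_logits)
_actor_loss = -(_dist.log_prob(_actions) * _adv).mean() - ENT_COEF * _dist.entropy().mean()

_v_pred = torch.tensor([0.8, 1.1])                   # critic predictions V_phi(s)
_returns = torch.tensor([1.0, 1.0])                  # regression targets \hat R
_critic_loss = 0.5 * (_returns - _v_pred).pow(2).mean()
print(float(_actor_loss), float(_critic_loss))
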
class Actor(nn.Module):
    def __init__(self, obs_dim: int, act_dim: int, hidden_sizes=(64, 64)):
        super().__init__()
        layers = []
        last = obs_dim
        for h in hidden_sizes:
            layers.append(nn.Linear(last, h))
            layers.append(nn.Tanh())
            last = h
        self.net = nn.Sequential(*layers)
        self.logits = nn.Linear(last, act_dim)

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        return self.logits(self.net(obs))


class Critic(nn.Module):
    def __init__(self, obs_dim: int, hidden_sizes=(64, 64)):
        super().__init__()
        layers = []
        last = obs_dim
        for h in hidden_sizes:
            layers.append(nn.Linear(last, h))
            layers.append(nn.Tanh())
            last = h
        self.net = nn.Sequential(*layers)
        self.v = nn.Linear(last, 1)

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        return self.v(self.net(obs)).squeeze(-1)


actor = Actor(obs_dim, act_dim).to(DEVICE)
critic = Critic(obs_dim).to(DEVICE)

critic_optim = torch.optim.Adam(critic.parameters(), lr=CRITIC_LR)

print('actor params', sum(p.numel() for p in actor.parameters()))
print('critic params', sum(p.numel() for p in critic.parameters()))
actor params 4610
critic params 4545

3) Rollouts + GAE#

We collect on-policy rollouts of length \(T\) and compute generalized advantage estimation (GAE):

\[ \delta_t = r_t + \gamma (1-d_t) V(s_{t+1}) - V(s_t) \]
\[ \hat A_t = \sum_{l=0}^{\infty} (\gamma\lambda)^l\,\delta_{t+l} \]

with \(d_t \in \{0,1\}\) indicating episode termination. In a finite segment of length \(T\), the tail of the sum is bootstrapped with the critic's value of the last observed state (the last_value argument below).

def compute_gae(rewards, values, dones, last_value, gamma: float, lam: float):
    """NumPy GAE for a single rollout segment."""
    T = len(rewards)
    advantages = np.zeros(T, dtype=np.float32)
    gae = 0.0
    for t in reversed(range(T)):
        next_value = last_value if t == T - 1 else values[t + 1]
        next_nonterminal = 1.0 - dones[t]
        delta = rewards[t] + gamma * next_value * next_nonterminal - values[t]
        gae = delta + gamma * lam * next_nonterminal * gae
        advantages[t] = gae
    returns = advantages + values
    return advantages, returns
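
A quick sanity check on a tiny hand-made segment (values chosen purely for illustration): with \(\gamma = 1\) and \(\lambda = 1\), the returns reduce to the remaining reward in the episode.

# Sanity check on a 3-step segment that terminates at the last step (illustrative values).
_rew = np.array([1.0, 1.0, 1.0], dtype=np.float32)
_val = np.array([0.5, 0.5, 0.5], dtype=np.float32)
_done = np.array([0.0, 0.0, 1.0], dtype=np.float32)
_adv, _ret = compute_gae(_rew, _val, _done, last_value=0.0, gamma=1.0, lam=1.0)
print('advantages', _adv)  # [2.5 1.5 0.5]
print('returns   ', _ret)  # [3. 2. 1.] = remaining reward per step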

4) K-FAC optimizer (Linear layers)#

ACKTR replaces a vanilla gradient step with a (preconditioned) natural gradient step.

For the policy parameters \(\theta\), the natural gradient direction is:

\[ \Delta\theta = F^{-1} g,\qquad g = \nabla_\theta J(\theta). \]

K-FAC approximates \(F\) block-wise per layer using Kronecker factors:

\[ F_{\ell} \approx G_{\ell} \otimes A_{\ell}, \quad A_{\ell}=\mathbb{E}[a a^\top],\quad G_{\ell}=\mathbb{E}[g g^\top]. \]

For a linear layer, this yields the matrix-form update (with damping):

\[ \Delta W_{\ell} \approx G_{\ell}^{-1}\,\nabla_{W_{\ell}}\mathcal{L}\,A_{\ell}^{-1}. \]

We also apply a trust-region-style scaling: the step is shrunk whenever the predicted KL change (estimated from \(g^\top F^{-1} g\)) exceeds a clip threshold, so the policy does not change too much in a single update.
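
The matrix-form update is just the Kronecker-factored natural gradient applied layer by layer. A small numerical check (arbitrary SPD factors, drawn from a local generator so the notebook's global seed is untouched) confirms that \(G_{\ell}^{-1}\,\nabla_{W_{\ell}}\mathcal{L}\,A_{\ell}^{-1}\) equals \((G_{\ell}^{-1} \otimes A_{\ell}^{-1})\,\mathrm{vec}(\nabla_{W_{\ell}}\mathcal{L})\) under row-major flattening:

# Numerical check of the Kronecker identity behind the layer-wise update.
# Arbitrary SPD factors; a local generator keeps the global RNG state intact.
_g = torch.Generator().manual_seed(0)
_G = torch.randn(3, 3, generator=_g, dtype=torch.float64)
_G = _G @ _G.t() + 0.1 * torch.eye(3, dtype=torch.float64)      # SPD "output" factor G
_A = torch.randn(4, 4, generator=_g, dtype=torch.float64)
_A = _A @ _A.t() + 0.1 * torch.eye(4, dtype=torch.float64)      # SPD "input" factor A
_grad_W = torch.randn(3, 4, generator=_g, dtype=torch.float64)  # stand-in for a weight gradient

_lhs = torch.linalg.inv(_G) @ _grad_W @ torch.linalg.inv(_A)
_rhs = (torch.kron(torch.linalg.inv(_G), torch.linalg.inv(_A)) @ _grad_W.reshape(-1)).reshape(3, 4)
print(torch.allclose(_lhs, _rhs))  # True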

class KFACOptimizer:
    """Minimal K-FAC for nn.Linear modules (actor only).

    - Collects factor stats (A,G) via forward/backward hooks on a Fisher-like loss.
    - Preconditions parameter gradients with G^{-1} @ grad @ A^{-1}.
    - Scales the step using a trust-region clip.
    """

    def __init__(
        self,
        model: nn.Module,
        lr: float,
        damping: float,
        stats_decay: float,
        kfac_clip: float,
        inverse_update_interval: int = 1,
    ):
        self.model = model
        self.lr = float(lr)
        self.damping = float(damping)
        self.stats_decay = float(stats_decay)
        self.kfac_clip = float(kfac_clip)
        self.inverse_update_interval = int(inverse_update_interval)

        self._collect_stats = False
        self._step = 0

        self.modules = []
        self.state = {}

        for module in self.model.modules():
            if isinstance(module, nn.Linear):
                self.modules.append(module)
                self.state[module] = {
                    'A': None,
                    'G': None,
                    'A_inv': None,
                    'G_inv': None,
                }
                module.register_forward_hook(self._forward_hook)
                module.register_full_backward_hook(self._backward_hook)

    def set_collect_stats(self, collect: bool):
        self._collect_stats = bool(collect)

    def _forward_hook(self, module, inputs, output):
        if not self._collect_stats:
            return
        module._kfac_input = inputs[0].detach()

    def _backward_hook(self, module, grad_input, grad_output):
        if not self._collect_stats:
            return
        module._kfac_grad_output = grad_output[0].detach()

    @torch.no_grad()
    def update_stats(self):
        for module in self.modules:
            if not hasattr(module, '_kfac_input') or not hasattr(module, '_kfac_grad_output'):
                continue

            a = module._kfac_input
            g = module._kfac_grad_output

            if a.dim() != 2 or g.dim() != 2:
                continue

            batch = a.shape[0]
            ones = torch.ones(batch, 1, device=a.device, dtype=a.dtype)
            a_aug = torch.cat([a, ones], dim=1)

            A_new = (a_aug.t() @ a_aug) / batch
            G_new = (g.t() @ g) / batch

            st = self.state[module]
            if st['A'] is None:
                st['A'] = A_new
                st['G'] = G_new
            else:
                d = self.stats_decay
                st['A'] = d * st['A'] + (1 - d) * A_new
                st['G'] = d * st['G'] + (1 - d) * G_new

        self._step += 1
        if self._step % self.inverse_update_interval == 0:
            self._update_inverses()

    @torch.no_grad()
    def _update_inverses(self):
        for module in self.modules:
            st = self.state[module]
            if st['A'] is None or st['G'] is None:
                continue

            A = st['A'] + self.damping * torch.eye(st['A'].shape[0], device=st['A'].device, dtype=st['A'].dtype)
            G = st['G'] + self.damping * torch.eye(st['G'].shape[0], device=st['G'].device, dtype=st['G'].dtype)

            st['A_inv'] = torch.linalg.inv(A)
            st['G_inv'] = torch.linalg.inv(G)

    @torch.no_grad()
    def step(self):
        eps = 1e-8

        shs = 0.0  # proxy for g^T F^{-1} g
        updates = []

        for module in self.modules:
            st = self.state[module]
            if st['A_inv'] is None or st['G_inv'] is None:
                continue

            if module.weight.grad is None:
                continue
            if module.bias is None or module.bias.grad is None:
                continue

            grad_w = module.weight.grad
            grad_b = module.bias.grad
            grad_wb = torch.cat([grad_w, grad_b.unsqueeze(1)], dim=1)

            nat_wb = st['G_inv'] @ grad_wb @ st['A_inv']
            nat_w = nat_wb[:, :-1]
            nat_b = nat_wb[:, -1]

            shs += float((grad_w * nat_w).sum().item() + (grad_b * nat_b).sum().item())
            updates.append((module.weight, nat_w))
            updates.append((module.bias, nat_b))

        # Trust-region / KL clip: only scale down.
        # (Theory: predicted KL \approx 0.5 * alpha^2 * g^T F^{-1} g)
        nu = 1.0
        if shs > 0:
            predicted_kl = 0.5 * shs
            nu = float(min(1.0, np.sqrt(self.kfac_clip / (predicted_kl + eps))))
        else:
            predicted_kl = 0.0

        for param, nat_grad in updates:
            param.add_(nat_grad, alpha=-self.lr * nu)

        return {
            'shs': shs,
            'predicted_kl': predicted_kl,
            'nu': nu,
        }


actor_kfac = KFACOptimizer(
    actor,
    lr=ACTOR_LR,
    damping=KFAC_DAMPING,
    stats_decay=KFAC_STATS_DECAY,
    kfac_clip=KFAC_CLIP,
    inverse_update_interval=INVERSE_UPDATE_INTERVAL,
)
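
Before wiring this into the training loop, here is how the trust-region scaling in step() behaves for a few hypothetical curvature values \(g^\top F^{-1} g\) (using KFAC_CLIP from the configuration cell): small curvature leaves the step untouched, large curvature shrinks it sharply.

# Trust-region scaling from step(), evaluated for hypothetical g^T F^{-1} g values.
for _shs in [1e-4, 1e-2, 1.0, 100.0]:
    _pred_kl = 0.5 * _shs
    _nu = min(1.0, float(np.sqrt(KFAC_CLIP / (_pred_kl + 1e-8))))
    print(f'shs={_shs:10.4f}  predicted_kl={_pred_kl:10.4f}  nu={_nu:6.3f}')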

5) Training loop (ACKTR update)#

Each update:

  1. Collect \(T\) on-policy transitions.

  2. Compute \(\hat A_t\) and \(\hat R_t\) with GAE.

  3. Update critic by minimizing \(\mathcal{L}_V\).

  4. For the actor:

    • build a Fisher-like loss (to collect K-FAC stats)

    • backprop that loss to update \(A\) and \(G\)

    • backprop the policy loss and take a K-FAC-preconditioned step

We log:

  • episodic returns

  • actor loss, critic loss, entropy

  • estimated KL (before/after update)

  • trust-region scale factor \(\nu\)

num_updates = TOTAL_TIMESTEPS // ROLLOUT_STEPS
print('num_updates', num_updates)

obs, _ = env.reset(seed=SEED)

episode_return = 0.0
episode_len = 0
episode_returns = []
episode_lengths = []

logs = []
start = time.time()

for update in range(1, num_updates + 1):
    # --- Rollout buffers ---
    obs_buf = np.zeros((ROLLOUT_STEPS, obs_dim), dtype=np.float32)
    act_buf = np.zeros((ROLLOUT_STEPS,), dtype=np.int64)
    rew_buf = np.zeros((ROLLOUT_STEPS,), dtype=np.float32)
    done_buf = np.zeros((ROLLOUT_STEPS,), dtype=np.float32)
    val_buf = np.zeros((ROLLOUT_STEPS,), dtype=np.float32)

    for t in range(ROLLOUT_STEPS):
        obs_buf[t] = obs

        obs_t = torch.tensor(obs, dtype=torch.float32, device=DEVICE).unsqueeze(0)
        with torch.no_grad():
            logits = actor(obs_t)
            dist = Categorical(logits=logits)
            action = dist.sample()
            value = critic(obs_t)

        next_obs, reward, terminated, truncated, _ = env.step(int(action.item()))
        done = bool(terminated or truncated)

        act_buf[t] = int(action.item())
        rew_buf[t] = float(reward)
        done_buf[t] = float(done)
        val_buf[t] = float(value.item())

        episode_return += float(reward)
        episode_len += 1

        obs = next_obs
        if done:
            episode_returns.append(episode_return)
            episode_lengths.append(episode_len)
            episode_return = 0.0
            episode_len = 0
            obs, _ = env.reset()

    with torch.no_grad():
        if done_buf[-1] == 1.0:
            last_value = 0.0
        else:
            last_obs_t = torch.tensor(obs, dtype=torch.float32, device=DEVICE).unsqueeze(0)
            last_value = float(critic(last_obs_t).item())

    advantages, returns = compute_gae(
        rewards=rew_buf,
        values=val_buf,
        dones=done_buf,
        last_value=last_value,
        gamma=GAMMA,
        lam=GAE_LAMBDA,
    )

    obs_batch = torch.tensor(obs_buf, dtype=torch.float32, device=DEVICE)
    act_batch = torch.tensor(act_buf, dtype=torch.int64, device=DEVICE)
    adv_batch = torch.tensor(advantages, dtype=torch.float32, device=DEVICE)
    ret_batch = torch.tensor(returns, dtype=torch.float32, device=DEVICE)

    adv_batch = (adv_batch - adv_batch.mean()) / (adv_batch.std() + 1e-8)

    # --- Critic update (first-order) ---
    critic_optim.zero_grad(set_to_none=True)
    v_pred = critic(obs_batch)
    critic_loss = 0.5 * (ret_batch - v_pred).pow(2).mean()
    critic_loss.backward()
    critic_optim.step()

    # --- Actor update (ACKTR-style via K-FAC) ---
    actor_kfac.set_collect_stats(True)
    logits_old = actor(obs_batch).detach()
    dist_old = Categorical(logits=logits_old)

    logits = actor(obs_batch)
    dist = Categorical(logits=logits)
    logp = dist.log_prob(act_batch)
    entropy = dist.entropy().mean()

    actor_loss = -(logp * adv_batch.detach()).mean() - ENT_COEF * entropy

    # Fisher-like loss: E[-log pi(a|s)]
    fisher_loss = -logp.mean()

    actor.zero_grad(set_to_none=True)
    fisher_loss.backward(retain_graph=True)
    actor_kfac.set_collect_stats(False)
    actor_kfac.update_stats()

    actor.zero_grad(set_to_none=True)
    actor_loss.backward()
    step_info = actor_kfac.step()

    with torch.no_grad():
        logits_new = actor(obs_batch)
        dist_new = Categorical(logits=logits_new)
        approx_kl = torch.distributions.kl_divergence(dist_old, dist_new).mean().item()

    logs.append(
        {
            'update': update,
            'timesteps': update * ROLLOUT_STEPS,
            'episodes': len(episode_returns),
            'actor_loss': float(actor_loss.item()),
            'critic_loss': float(critic_loss.item()),
            'entropy': float(entropy.item()),
            'approx_kl': float(approx_kl),
            **step_info,
        }
    )

    if update % 25 == 0:
        recent = episode_returns[-20:]
        mean_20 = float(np.mean(recent)) if recent else float('nan')
        elapsed = time.time() - start
        print(
            f'update {update:4d}/{num_updates} | episodes {len(episode_returns):4d} '
            f'| mean_return_20 {mean_20:7.2f} | kl {approx_kl:9.2e} | nu {step_info["nu"]:7.3f} '
            f'| elapsed {elapsed:6.1f}s'
        )

env.close()
num_updates 312
update   25/312 | episodes   56 | mean_return_20   93.55 | kl  4.40e-03 | nu   0.062 | elapsed    0.6s
update   50/312 | episodes   69 | mean_return_20  180.05 | kl  3.59e-02 | nu   0.066 | elapsed    1.1s
update   75/312 | episodes   87 | mean_return_20  209.00 | kl  5.94e-02 | nu   0.044 | elapsed    1.7s
update  100/312 | episodes  104 | mean_return_20  174.20 | kl  3.60e-02 | nu   0.036 | elapsed    2.3s
update  125/312 | episodes  121 | mean_return_20  178.80 | kl  3.47e-02 | nu   0.030 | elapsed    2.8s
update  150/312 | episodes  140 | mean_return_20  159.80 | kl  3.95e-02 | nu   0.044 | elapsed    3.4s
update  175/312 | episodes  159 | mean_return_20  172.60 | kl  1.87e-02 | nu   0.040 | elapsed    4.0s
update  200/312 | episodes  173 | mean_return_20  216.50 | kl  6.78e-02 | nu   0.019 | elapsed    4.6s
update  225/312 | episodes  192 | mean_return_20  156.50 | kl  9.28e-02 | nu   0.197 | elapsed    5.2s
update  250/312 | episodes  211 | mean_return_20  186.60 | kl  3.06e-02 | nu   0.049 | elapsed    5.7s
update  275/312 | episodes  245 | mean_return_20   90.65 | kl  1.04e-02 | nu   0.063 | elapsed    6.4s
update  300/312 | episodes  273 | mean_return_20  118.30 | kl  2.05e-02 | nu   0.048 | elapsed    6.9s

6) Plotly: learning dynamics#

We visualize:

  • episodic reward progression (raw + smoothed)

  • estimated KL per update

  • actor/critic losses

  • trust-region scaling factor \(\nu\)

df_logs = pd.DataFrame(logs)
df_eps = pd.DataFrame({'episode': np.arange(1, len(episode_returns) + 1), 'return': episode_returns})
df_eps['return_smooth'] = df_eps['return'].rolling(window=20, min_periods=1).mean()

fig = go.Figure()
fig.add_trace(go.Scatter(x=df_eps['episode'], y=df_eps['return'], mode='lines', name='return', line=dict(width=1)))
fig.add_trace(go.Scatter(x=df_eps['episode'], y=df_eps['return_smooth'], mode='lines', name='return (20-ep mean)', line=dict(width=3)))
fig.update_layout(
    title='Episodic reward progression (CartPole-v1)',
    xaxis_title='Episode',
    yaxis_title='Return',
    height=420,
)
fig.show()

fig2 = px.line(df_logs, x='timesteps', y=['approx_kl', 'predicted_kl'], title='KL diagnostics per update')
fig2.update_layout(height=380)
fig2.show()

fig3 = px.line(df_logs, x='timesteps', y=['actor_loss', 'critic_loss'], title='Losses per update')
fig3.update_layout(height=380)
fig3.show()

fig4 = px.line(df_logs, x='timesteps', y=['nu'], title='Trust-region scaling (nu)')
fig4.update_layout(height=320)
fig4.show()

7) Stable-Baselines ACKTR (reference)#

We’ll include a reference snippet for the (TensorFlow-based) Stable-Baselines implementation of ACKTR, plus an explanation of its key hyperparameters.

This section is reference only — the implementation above is the main deliverable.

Stable-Baselines usage (snippet)#

# pip install stable-baselines==2.*  (TensorFlow 1.x based)
from stable_baselines import ACKTR

model = ACKTR(
    policy='MlpPolicy',
    env='CartPole-v1',
    n_steps=20,
    gamma=0.99,
    ent_coef=0.01,
    vf_coef=0.25,
    vf_fisher_coef=1.0,
    learning_rate=0.25,
    max_grad_norm=0.5,
    kfac_clip=0.001,
    lr_schedule='linear',
    kfac_update=1,
    gae_lambda=None,
    verbose=1,
)

model.learn(total_timesteps=200_000)

Hyperparameters (Stable-Baselines) explained#

Stable-Baselines (“v2”, TensorFlow 1.x) includes an ACKTR implementation (see stable_baselines/acktr/acktr.py). The constructor signature is:

ACKTR(
  policy,
  env,
  gamma=0.99,
  n_steps=20,
  ent_coef=0.01,
  vf_coef=0.25,
  vf_fisher_coef=1.0,
  learning_rate=0.25,
  max_grad_norm=0.5,
  kfac_clip=0.001,
  lr_schedule='linear',
  async_eigen_decomp=False,
  kfac_update=1,
  gae_lambda=None,
  policy_kwargs=None,
  seed=None,
  n_cpu_tf_sess=1,
  # + logging/boilerplate args
)

Core RL knobs

  • gamma: discount factor.

  • n_steps: rollout length per environment before each update.

  • gae_lambda: if not None, Stable-Baselines computes GAE with parameter \(\lambda\); if None, it uses the classic advantage (no GAE).

  • ent_coef: entropy bonus weight (encourages exploration).

  • vf_coef: value loss weight in the joint loss.

ACKTR / K-FAC + trust region knobs

  • kfac_clip: KL-based clip used inside the K-FAC optimizer (trust-region-like safeguard; called clip_kl in the underlying optimizer).

  • vf_fisher_coef: weight on the value-function Fisher loss. In the Stable-Baselines code, the value Fisher is constructed by adding noise to the value output and backpropagating a Gaussian negative log-likelihood; this lets K-FAC build curvature stats for the critic (see the sketch after this list).

  • learning_rate: the step size used by the K-FAC optimizer (and scheduled by lr_schedule).

  • lr_schedule: learning-rate schedule string ('linear', 'constant', 'double_linear_con', 'middle_drop', 'double_middle_drop').

  • kfac_update: update frequency for K-FAC statistics / eigen decompositions.

  • async_eigen_decomp: compute eigen decompositions asynchronously (speed/throughput trade-off).

  • max_grad_norm: global gradient clipping.
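
As referenced above, here is a rough PyTorch sketch of the value-Fisher idea behind vf_fisher_coef (illustrative only; Stable-Baselines implements this in TensorFlow inside its K-FAC optimizer):

# Rough sketch of the "value Fisher" construction (not the Stable-Baselines code).
# Sample a noisy target around V(s) and form the Gaussian NLL (up to constants);
# backpropagating this loss is what feeds curvature statistics for the critic to K-FAC.
_g = torch.Generator().manual_seed(2)
_obs = torch.randn(8, obs_dim, generator=_g)                          # stand-in observations
_v = critic(_obs)
_noisy_target = (_v + torch.randn(_v.shape, generator=_g)).detach()   # sample around V(s)
_vf_fisher_loss = 0.5 * (_v - _noisy_target).pow(2).mean()            # Gaussian NLL up to constants
critic.zero_grad(set_to_none=True)
_vf_fisher_loss.backward()            # gradients like these are what K-FAC turns into curvature stats
critic.zero_grad(set_to_none=True)    # clean up: this cell is reference-only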

Practical / reproducibility knobs

  • policy: policy network type (e.g. MlpPolicy, CnnPolicy, CnnLstmPolicy).

  • env: Gym env instance or env id string.

  • policy_kwargs: extra arguments forwarded to the policy.

  • seed: seeds python/NumPy/TensorFlow RNGs.

  • n_cpu_tf_sess: TensorFlow thread count (for determinism, set this to 1).

Note: Stable-Baselines wires ACKTR into kfac.KfacOptimizer(...) with additional internal defaults (e.g. momentum=0.9, epsilon=0.01, stats_decay=0.99, cold_iter=10).